{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabulate import tabulate\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "execution_plans_small = [\"DP\", \"GC\",\"ZeRO-DP+GA\",\"ZeRO-Offload\"]\n",
    "execution_plans_large = [\"TP+PP\", \"DP+TP+PP\",\"ZeRO-DP+GA\",\"ZeRO-Offload+GC\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_table(model_prediction,execution_plans):\n",
    "    validation_data = []\n",
    "    for execution_plan in execution_plans:\n",
    "        if execution_plan not in model_prediction:\n",
    "            validation_data.append([execution_plan,\n",
    "                                \"/\",\n",
    "                                \"/\"])\n",
    "        else:\n",
    "            validation_data.append([execution_plan,\n",
    "                                round(np.average(model_prediction[execution_plan])*100,3),\n",
    "                                round(np.max(model_prediction[execution_plan])*100,3)])\n",
    "\n",
    "    headers = [\"execution_plan\", \"avg.\", \"max.\"]\n",
    "    table = tabulate(validation_data, headers, tablefmt=\"grid\")\n",
    "    return table\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ViT Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   |   avg. |   max. |\n",
      "+==================+========+========+\n",
      "| DP               |  3.631 |  6.828 |\n",
      "+------------------+--------+--------+\n",
      "| GC               |  2.592 |  6.186 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       |  4.231 |  6.669 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload     |  3.002 |  8.315 |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_vit import fit_ViT, validate\n",
    "vit_fittable_params = fit_ViT()\n",
    "vit_prediction = validate(vit_fittable_params)\n",
    "table=draw_table(vit_prediction,execution_plans_small)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RoBERTa Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   |   avg. |   max. |\n",
      "+==================+========+========+\n",
      "| DP               |  2.209 |  4.365 |\n",
      "+------------------+--------+--------+\n",
      "| GC               |  3.358 |  4.292 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       |  3.593 |  6.712 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload     |  7.418 | 10.436 |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_roberta import fit_RoBERTa, validate\n",
    "roberta_fittable_params = fit_RoBERTa()\n",
    "roberta_prediction = validate(roberta_fittable_params)\n",
    "table=draw_table(roberta_prediction,execution_plans_small)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERT Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   |   avg. |   max. |\n",
      "+==================+========+========+\n",
      "| DP               |  5.268 |  8.32  |\n",
      "+------------------+--------+--------+\n",
      "| GC               |  4.9   |  7.27  |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       |  3.701 |  6.899 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload     |  6.372 |  8.618 |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_bert import fit_BERT, validate\n",
    "bert_fittable_params = fit_BERT()\n",
    "bert_prediction = validate(bert_fittable_params)\n",
    "table=draw_table(bert_prediction,execution_plans_small)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## T5 Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   |   avg. |   max. |\n",
      "+==================+========+========+\n",
      "| TP+PP            |  3.184 |  8.239 |\n",
      "+------------------+--------+--------+\n",
      "| DP+TP+PP         |  2.406 |  9.55  |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       |  6.711 |  9.554 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload+GC  |  4.368 |  6.335 |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validation_t5 import fit_T5, validate\n",
    "t5_fittable_params = fit_T5()\n",
    "t5_prediction = validate(t5_fittable_params)\n",
    "table=draw_table(t5_prediction,execution_plans_large)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPT-2 Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   |   avg. |   max. |\n",
      "+==================+========+========+\n",
      "| TP+PP            |  2.39  |  3.081 |\n",
      "+------------------+--------+--------+\n",
      "| DP+TP+PP         |  2.8   |  4.151 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       |  2.516 |  3.855 |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload+GC  |  5.521 |  8.901 |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_gpt import fit_GPT2, validate\n",
    "gpt2_fittable_params = fit_GPT2()\n",
    "gpt2_prediction = validate(gpt2_fittable_params)\n",
    "table=draw_table(gpt2_prediction,execution_plans_large)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLaMA-2-7B Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   | avg.   | max.   |\n",
      "+==================+========+========+\n",
      "| TP+PP            | 1.898  | 2.898  |\n",
      "+------------------+--------+--------+\n",
      "| DP+TP+PP         | 4.699  | 9.453  |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       | /      | /      |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload+GC  | 4.094  | 6.382  |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_llama7B import fit_LLaMA7B, validate\n",
    "llama7b_fittable_params = fit_LLaMA7B()\n",
    "llama7b_prediction = validate(llama7b_fittable_params)\n",
    "table=draw_table(llama7b_prediction,execution_plans_large)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLaMA-30B Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+--------+--------+\n",
      "| execution_plan   | avg.   | max.   |\n",
      "+==================+========+========+\n",
      "| TP+PP            | 4.291  | 8.524  |\n",
      "+------------------+--------+--------+\n",
      "| DP+TP+PP         | 6.149  | 9.69   |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-DP+GA       | /      | /      |\n",
      "+------------------+--------+--------+\n",
      "| ZeRO-Offload+GC  | /      | /      |\n",
      "+------------------+--------+--------+\n"
     ]
    }
   ],
   "source": [
    "from model.validate_llama30B import fit_LLaMA30B, validate\n",
    "llama30b_fittable_params = fit_LLaMA30B()\n",
    "llama30b_prediction = validate(llama30b_fittable_params)\n",
    "table=draw_table(llama30b_prediction,execution_plans_large)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rubick",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
